import os 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.distributions.uniform import Uniform
import pickle
from collections import namedtuple
from itertools import chain, repeat
import numpy as np
from e2cnn import gspaces
from e2cnn import nn

class RCNNIT(torch.nn.Module):
    """Rotation-equivariant convolutional feature extractor built with e2cnn.

    The network is six 3x3 ``R2Conv -> InnerBatchNorm -> ReLU`` blocks
    (``block3`` .. ``block8``) grouped in pairs. Each pair is followed by an
    antialiased pointwise average pool (``pool2``, ``pool3``, ``pool4``).
    Every feature map uses the same field type: ``n_feats`` copies of the
    regular representation of the cyclic rotation group C_N acting on the
    plane.

    Args:
        n_feats: Number of regular fields per feature map. The input tensor
            must have ``self.input_type.size`` channels, which is
            ``n_feats * N`` because the regular representation of C_N has
            dimension N.
        N: Order of the rotation group C_N. The default of 8 gives
            equivariance to rotations by multiples of 45 degrees (2*pi/8).
    """

    def __init__(self, n_feats: int = 48, N: int = 8):
        super().__init__()
        # Import locally under an alias: the module-level
        # ``from e2cnn import nn`` shadows ``torch.nn``, which is confusing.
        from e2cnn import gspaces
        from e2cnn import nn as enn

        self.r2_act = gspaces.Rot2dOnR2(N=N)

        def feature_type():
            # n_feats regular fields of C_N (the type of every feature map).
            return enn.FieldType(self.r2_act, n_feats * [self.r2_act.regular_repr])

        def conv_block(in_type):
            # 3x3 equivariant conv (padding=1) + batch norm + ReLU.
            out_type = feature_type()
            return enn.SequentialModule(
                enn.R2Conv(in_type, out_type, kernel_size=3, padding=1, bias=False),
                enn.InnerBatchNorm(out_type),
                enn.ReLU(out_type, inplace=True),
            )

        def avg_pool(in_type, stride):
            # Gaussian-antialiased pointwise average pooling (sigma fixed at 0.66).
            return enn.PointwiseAvgPoolAntialiased(in_type, sigma=0.66, stride=stride)

        self.input_type = feature_type()

        # Each layer is built from the output type of the layer that feeds
        # it in forward(). The wiring therefore stays valid even if the
        # per-layer widths ever diverge. Construction order is kept so that
        # seeded weight initialization is reproducible.
        self.block3 = conv_block(self.input_type)
        self.block4 = conv_block(self.block3.out_type)
        self.pool2 = avg_pool(self.block4.out_type, stride=2)

        self.block5 = conv_block(self.pool2.out_type)
        self.block6 = conv_block(self.block5.out_type)
        self.pool3 = avg_pool(self.block6.out_type, stride=2)

        self.block7 = conv_block(self.pool3.out_type)
        self.block8 = conv_block(self.block7.out_type)
        # stride=1: this pool smooths the features but does not downsample.
        self.pool4 = avg_pool(self.block8.out_type, stride=1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the equivariant feature extractor.

        Args:
            input: Raw tensor whose channel dimension (dim 1) must equal
                ``self.input_type.size``. It is wrapped in a
                ``GeometricTensor`` of type ``self.input_type``.

        Returns:
            The plain ``torch.Tensor`` underlying the final feature map
            (after ``pool4``).
        """
        from e2cnn import nn as enn

        # Associate the raw tensor with the input field type so the e2cnn
        # layers can check and preserve equivariance.
        x = enn.GeometricTensor(input, self.input_type)

        x = self.pool2(self.block4(self.block3(x)))
        x = self.pool3(self.block6(self.block5(x)))
        x = self.pool4(self.block8(self.block7(x)))

        return x.tensor
